{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "原文代码作者：https://blog.csdn.net/tudaodiaozhale"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class HiddenMarkov:\n",
    "    \"\"\"Forward, backward and Viterbi algorithms for a discrete hidden Markov model.\n",
    "\n",
    "    All three public methods take the same arguments:\n",
    "\n",
    "    Q  -- list of hidden states; only its length N is used.\n",
    "    V  -- list of observable symbols; each observation is located with V.index().\n",
    "    A  -- N x N transition matrix, A[i][j] = P(next state j | current state i).\n",
    "    B  -- N x len(V) emission matrix, B[i][k] = P(symbol V[k] | state i).\n",
    "    O  -- observation sequence; every element must occur in V.\n",
    "    PI -- initial state distribution, either [[p1, ..., pN]] or [p1, ..., pN].\n",
    "\n",
    "    Each method prints its intermediate quantities, numbering states and time\n",
    "    steps from 1, and returns its final result.\n",
    "    \"\"\"\n",
    "\n",
    "    @staticmethod\n",
    "    def _as_arrays(A, B, PI):\n",
    "        \"\"\"Return A and B as 2-D float arrays and PI flattened to a 1-D float array.\"\"\"\n",
    "        return (np.asarray(A, dtype=float),\n",
    "                np.asarray(B, dtype=float),\n",
    "                np.asarray(PI, dtype=float).ravel())\n",
    "\n",
    "    def forward(self, Q, V, A, B, O, PI):\n",
    "        \"\"\"Run the forward algorithm and return P(O|lambda) as a float.\n",
    "\n",
    "        alphas[i][t] is the probability of observing O[0..t] and being in state i\n",
    "        at step t; the result is the sum of the last column.\n",
    "        \"\"\"\n",
    "        A, B, pi = HiddenMarkov._as_arrays(A, B, PI)\n",
    "        N = len(Q)\n",
    "        T = len(O)\n",
    "        alphas = np.zeros((N, T))\n",
    "        for t in range(T):\n",
    "            realT = t + 1\n",
    "            indexOfO = V.index(O[t])\n",
    "            for i in range(N):\n",
    "                realI = i + 1\n",
    "                if t == 0:\n",
    "                    alphas[i][t] = pi[i] * B[i][indexOfO]\n",
    "                    print('alpha1(%d)=pi%d*b%d(o1)=%f' % (realI, realI, realI, alphas[i][t]))\n",
    "                else:\n",
    "                    # Sum over every predecessor j of alpha_{t-1}(j) * a_ji, then emit o_t.\n",
    "                    alphas[i][t] = np.dot(alphas[:, t - 1], A[:, i]) * B[i][indexOfO]\n",
    "                    print('alpha%d(%d)=[sigma alpha%d(i)ai%d]b%d(o%d)=%f' % (realT, realI, realT - 1, realI, realI, realT, alphas[i][t]))\n",
    "        P = float(np.sum(alphas[:, T - 1]))\n",
    "        print('P(O|lambda)=%f' % P)\n",
    "        return P\n",
    "\n",
    "    def backward(self, Q, V, A, B, O, PI):\n",
    "        \"\"\"Run the backward algorithm and return P(O|lambda) as a float.\n",
    "\n",
    "        betas[i][t] is the probability of observing O[t+1..] given state i at\n",
    "        step t; the last column is initialised to 1.\n",
    "        \"\"\"\n",
    "        A, B, pi = HiddenMarkov._as_arrays(A, B, PI)\n",
    "        N = len(Q)\n",
    "        T = len(O)\n",
    "        betas = np.ones((N, T))\n",
    "        for i in range(N):\n",
    "            print('beta%d(%d)=1' % (T, i + 1))\n",
    "        for t in range(T - 2, -1, -1):\n",
    "            realT = t + 1\n",
    "            indexOfO = V.index(O[t + 1])\n",
    "            for i in range(N):\n",
    "                realI = i + 1\n",
    "                betas[i][t] = np.dot(A[i] * B[:, indexOfO], betas[:, t + 1])\n",
    "                print('beta%d(%d)=[sigma a%djbj(o%d)]beta%d(j)=(' % (realT, realI, realI, realT + 1, realT + 1),\n",
    "                      end='')\n",
    "                for j in range(N):\n",
    "                    print('%.2f*%.2f*%.2f+' % (A[i][j], B[j][indexOfO], betas[j][t + 1]), end='')\n",
    "                print('0)=%.3f' % betas[i][t])\n",
    "        indexOfO = V.index(O[0])\n",
    "        # pi is 1-D here, so the dot product is a scalar rather than a 1-element array.\n",
    "        P = float(np.dot(pi * B[:, indexOfO], betas[:, 0]))\n",
    "        print('P(O|lambda)=', end='')\n",
    "        for i in range(N):\n",
    "            print('%.1f*%.1f*%.5f+' % (pi[i], B[i][indexOfO], betas[i][0]), end='')\n",
    "        print('0=%f' % P)\n",
    "        return P\n",
    "\n",
    "    def viterbi(self, Q, V, A, B, O, PI):\n",
    "        \"\"\"Run the Viterbi algorithm and return the most probable state path.\n",
    "\n",
    "        The path is a list of 1-based state numbers, one per observation.\n",
    "        deltas[i][t] is the best path probability ending in state i at step t and\n",
    "        psis[i][t] is the 1-based number of the state that precedes it on that\n",
    "        path (0 in the first column, which has no predecessor).\n",
    "        \"\"\"\n",
    "        A, B, pi = HiddenMarkov._as_arrays(A, B, PI)\n",
    "        N = len(Q)\n",
    "        T = len(O)\n",
    "        deltas = np.zeros((N, T))\n",
    "        psis = np.zeros((N, T), dtype=int)\n",
    "        for t in range(T):\n",
    "            realT = t + 1\n",
    "            indexOfO = V.index(O[t])\n",
    "            for i in range(N):\n",
    "                realI = i + 1\n",
    "                if t == 0:\n",
    "                    deltas[i][t] = pi[i] * B[i][indexOfO]\n",
    "                    print('delta1(%d)=pi%d * b%d(o1)=%.2f * %.2f=%.2f' % (realI, realI, realI, pi[i], B[i][indexOfO], deltas[i][t]))\n",
    "                    print('psis1(%d)=0' % (realI))\n",
    "                else:\n",
    "                    # Score every predecessor once and reuse it for both max and argmax.\n",
    "                    scores = deltas[:, t - 1] * A[:, i]\n",
    "                    best = int(np.argmax(scores))\n",
    "                    deltas[i][t] = scores[best] * B[i][indexOfO]\n",
    "                    psis[i][t] = best + 1\n",
    "                    print('delta%d(%d)=max[delta%d(j)aj%d]b%d(o%d)=%.2f*%.2f=%.5f' % (realT, realI, realT - 1, realI, realI, realT, scores[best], B[i][indexOfO], deltas[i][t]))\n",
    "                    print('psis%d(%d)=argmax[delta%d(j)aj%d]=%d' % (realT, realI, realT - 1, realI, psis[i][t]))\n",
    "        print(deltas)\n",
    "        print(psis)\n",
    "        path = [0] * T\n",
    "        path[T - 1] = int(np.argmax(deltas[:, T - 1])) + 1\n",
    "        print('i%d=argmax[deltaT(i)]=%d' % (T, path[T - 1]))\n",
    "        for t in range(T - 2, -1, -1):\n",
    "            # Follow the stored predecessor of the state chosen for step t+1.\n",
    "            path[t] = int(psis[path[t + 1] - 1][t + 1])\n",
    "            print('i%d=psis%d(i%d)=%d' % (t + 1, t + 2, t + 2, path[t]))\n",
    "        print(path)\n",
    "        return path"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 习题10.1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Model for exercise 10.1: three hidden states and two observable symbols.\n",
    "Q = [1, 2, 3]\n",
    "V = ['红', '白']\n",
    "\n",
    "# State transition matrix, one row per current state.\n",
    "A = [\n",
    "    [0.5, 0.2, 0.3],\n",
    "    [0.3, 0.5, 0.2],\n",
    "    [0.2, 0.3, 0.5],\n",
    "]\n",
    "\n",
    "# Emission matrix, one row per state and one column per symbol in V.\n",
    "B = [\n",
    "    [0.5, 0.5],\n",
    "    [0.4, 0.6],\n",
    "    [0.7, 0.3],\n",
    "]\n",
    "\n",
    "# Observation sequence, then the initial state distribution as a 1 x N nested list.\n",
    "O = ['红', '白', '红', '白']\n",
    "PI = [[0.2, 0.4, 0.4]]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Decode the most probable state sequence for the model defined in the cell above.\n",
    "# HMM is kept as a notebook-level name because a later cell reuses it.\n",
    "HMM = HiddenMarkov()\n",
    "HMM.viterbi(Q=Q, V=V, A=A, B=B, O=O, PI=PI)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 习题10.2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Model for exercise 10.2: same states, symbols and matrices as exercise 10.1,\n",
    "# with a longer observation sequence and a different initial distribution.\n",
    "Q = [1, 2, 3]\n",
    "V = ['红', '白']\n",
    "\n",
    "# State transition matrix, one row per current state.\n",
    "A = [\n",
    "    [0.5, 0.2, 0.3],\n",
    "    [0.3, 0.5, 0.2],\n",
    "    [0.2, 0.3, 0.5],\n",
    "]\n",
    "\n",
    "# Emission matrix, one row per state and one column per symbol in V.\n",
    "B = [\n",
    "    [0.5, 0.5],\n",
    "    [0.4, 0.6],\n",
    "    [0.7, 0.3],\n",
    "]\n",
    "\n",
    "# Observation sequence, then the initial state distribution as a 1 x N nested list.\n",
    "O = ['红', '白', '红', '红', '白', '红', '白', '白']\n",
    "PI = [[0.2, 0.3, 0.5]]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run the forward pass and then the backward pass on the exercise 10.2 model.\n",
    "# HMM is the HiddenMarkov instance created in the exercise 10.1 cell.\n",
    "HMM.forward(Q=Q, V=V, A=A, B=B, O=O, PI=PI)\n",
    "HMM.backward(Q=Q, V=V, A=A, B=B, O=O, PI=PI)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
